import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import statsmodels.formula.api as sm

from cat_rfs.code.utils.tools import save_stat
from cat_rfs.code.utils.graph import convert_to_percentages, add_nber_recessions


# Catastrophe events marked with a dashed vertical line on both Figure 2 panels.
# Katrina is handled separately by ``_add_event_lines`` because its line is drawn in two segments.
_KATRINA_DATE = '8/31/2005'
_CAT_EVENT_DATES = (
    '9/30/2008',   # Ike
    '3/31/2011',   # Tohoku
    '10/31/2012',  # Sandy
    '9/30/2017',   # HIM
    '10/31/2018',  # Michael
    '9/30/2021',   # Hurricane Ida
    '9/30/2022',   # Hurricane Ian
)


def _add_event_lines(ax, katrina_ymax, katrina_ymin):
    """Draw dashed vertical lines at the catastrophe event dates on ``ax``.

    The Katrina line is drawn as two segments spanning axes fractions ``[0, katrina_ymax]`` and
    ``[katrina_ymin, 1]``, leaving a gap between them (presumably so it does not cross the
    R^2/beta annotation text drawn near the top-left of each panel).
    """
    katrina = pd.to_datetime(_KATRINA_DATE)
    ax.axvline(x=katrina, ymax=katrina_ymax, ls='--', c='.3')
    ax.axvline(x=katrina, ymin=katrina_ymin, ls='--', c='.3')
    for event_date in _CAT_EVENT_DATES:
        ax.axvline(x=pd.to_datetime(event_date), ls='--', c='.3')


def print_fig2(output='console'):
    """
    Print Figure 2: Cumulative excess returns on selected indices.
    Contains return time series for Swiss Re Global Cat Bond Total Return Index, Eurekahedge ILS Advisers index,
    CRSP value weighted index, and Bloomberg Barclays U.S. Corporate High Yield Total Return Index.

    Panel A plots cumulative cat bond vs. ILS fund returns; Panel B plots cumulative cat bond, equity and
    high-yield bond returns. Regression R^2 and slope statistics are passed to ``save_stat``.

    Parameters
    ----------
    output : str
        Forwarded to ``save_stat``. When ``'paper'``, each panel is saved as an EPS file under
        ``cat_rfs/output/figures/`` and closed; otherwise both figures are left open and displayed by
        the final ``plt.show()``.
    """
    df = pd.read_csv('cat_rfs/data/cat_vs_others.csv', parse_dates=['mdate'])

    for var in ('hy', 'swissre'):  # Convert from index level to simple return
        df[f'ret_{var}'] = df[f'ix_{var}'] / df[f'ix_{var}'].shift(1) - 1

    # Cumulative returns are plain (non-compounded) running sums; missing values are skipped by cumsum.
    for var in ('mmret', 'ret_hy', 'ret_swissre'):
        df[var + '_cum'] = df[var].cumsum()

    # Make the ILS Advisers index start at the Cat Bond Total Return Index level in 2005 and convert to
    # cumulative: the ILS return on 2005-12-31 is overwritten with the cat bond's cumulative return on that
    # date, so the running sum of ILS returns starts from the cat bond level.
    # NOTE(review): assumes ilsfund_ret has no observations before 2005-12-31 — TODO confirm.
    ils_start = df['mdate'] == '12/31/2005'
    df['ilsfund_ret2'] = df['ilsfund_ret'].copy()
    df.loc[ils_start, 'ilsfund_ret2'] = df.loc[ils_start, 'ret_swissre_cum']
    df['ilsfund_ret_cum'] = df['ilsfund_ret2'].cumsum()

    # Run analyses and print graph (Panel A)
    df = df.set_index('mdate')
    ols = sm.ols('ilsfund_ret ~ ret_swissre', data=df).fit()
    # Look coefficients up by name: positional indexing of a label-indexed Series (params[1]) is deprecated.
    beta_ils = ols.params['ret_swissre']

    save_stat(ols.rsquared, 'r2_ils_swissre', formatting='%.2f', output=output)
    save_stat(beta_ils, 'beta_ils_swissre', formatting='%.2f', output=output)

    _, ax = plt.subplots(figsize=(6.5, 3.3))

    ax.plot(df[['ret_swissre_cum', 'ilsfund_ret_cum']])

    ax.set_xlabel('')
    # Raw strings: the LaTeX thin space '\,' is not a valid Python escape sequence.
    ax.legend([r'$R^e_{cat}$',
               r'$R^e_{cat\,manager}$'], loc='lower right')
    ax.set_xlim(df.index.min(), df.index.max())
    ax = convert_to_percentages(ax, axis='y', digits=0)

    ax.text(pd.to_datetime('6/1/2002'), 1.16, r'$R^2_{cat\,manager,cat}=$ ' + str(ols.rsquared.round(2)))
    ax.text(pd.to_datetime('6/1/2002'), 1.04, r'$\beta_{cat\,manager,cat}=$ ' + str(beta_ils.round(2)))
    add_nber_recessions(ax, color='lightgray')

    _add_event_lines(ax, katrina_ymax=0.70, katrina_ymin=0.93)

    if output == 'paper':
        plt.savefig('cat_rfs/output/figures/fig2_a.eps')
        plt.close()

    # Run analyses and print graph (Panel B)
    # The HAC covariance choice only affects standard errors, not the R^2 / slopes reported here.
    olse = sm.ols('ret_swissre ~ mmret', data=df).fit(use_t=True, cov_type='HAC', cov_kwds={'maxlags': 0})
    olsb = sm.ols('ret_swissre ~ ret_hy', data=df).fit(use_t=True, cov_type='HAC', cov_kwds={'maxlags': 0})
    olsbe = sm.ols('ret_hy ~ mmret', data=df).fit(use_t=True)
    beta_eq = olse.params['mmret']
    beta_hy = olsb.params['ret_hy']
    beta_hy_eq = olsbe.params['mmret']

    save_stat(olse.rsquared, 'r2_swissre_eq', formatting='%.2f', output=output)
    save_stat(beta_eq, 'beta_swissre_eq', formatting='%.2f', output=output)
    save_stat(olsb.rsquared, 'r2_swissre_hy', formatting='%.2f', output=output)
    save_stat(beta_hy, 'beta_swissre_hy', formatting='%.2f', output=output)
    save_stat(olsbe.rsquared, 'r2_hy_eq', formatting='%.2f', output=output)
    save_stat(beta_hy_eq, 'beta_hy_eq', formatting='%.2f', output=output)

    _, ax = plt.subplots(figsize=(6.5, 3.3))

    ax.plot(df[['ret_swissre_cum', 'mmret_cum', 'ret_hy_cum']])

    ax.set_xlabel('')
    ax.legend([r'$R^e_{cat}$',
               r'$R^e_{equity}$',
               r'$R^e_{hy\,bonds}$'], loc='lower right')
    ax.set_xlim(df.index.min(), df.index.max())
    ax = convert_to_percentages(ax, axis='y', digits=0)

    ax.text(pd.to_datetime('6/1/2002'), 1.88, r'$R^2_{cat, equity}=$ ' + str(olse.rsquared.round(2)))
    ax.text(pd.to_datetime('6/1/2002'), 1.67, r'$\beta_{cat,equity}=$ ' + str(beta_eq.round(2)))

    ax.text(pd.to_datetime('6/1/2002'), 1.46, r'$R^2_{cat, hy\,bonds}=$ ' + str(olsb.rsquared.round(2)))
    ax.text(pd.to_datetime('6/1/2002'), 1.25, r'$\beta_{cat, hy\,bonds}=$ ' + str(beta_hy.round(2)))

    add_nber_recessions(ax, color='lightgray')

    _add_event_lines(ax, katrina_ymax=0.60, katrina_ymin=0.95)

    if output == 'paper':
        plt.savefig('cat_rfs/output/figures/fig2_b.eps')
        plt.close()
    else:
        # Displays every open figure, including Panel A (which is only closed in 'paper' mode).
        plt.show()
